"""
Extensive Margin Puzzle Analysis for Gravity Bilateral Paper.

Key puzzle: Extensive margin logit shows ΔZ₁ = -2.896 (p<0.001) — aging origin
means FEWER bilateral connections. But intensive margin GLS shows ΔZ₁ = +0.815
(p=0.208) — aging origin means LARGER positions (NS). The signs are opposite.

This script investigates:
1. Reporter-level correlations between Z_1 and the number of connections / average position size
2. Extensive margin with KAOPEN interactions
3. OECD vs non-OECD split
4. Structural zeros vs genuine zeros decomposition
5. Portfolio concentration (Herfindahl index and top-5 partner share)

Output: output/tables/extensive_margin_analysis.md (relative to BASE_DIR)
"""

import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
import statsmodels.api as sm
import warnings
warnings.filterwarnings('ignore')

# Absolute project root; every input and output path below hangs off it.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral_followup")
PROCESSED_DIR = BASE_DIR.joinpath("data", "processed")
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")

# The 38 OECD member states (membership as of ~2024), as ISO3 codes. The
# analysis sections compare these against the `iso_o` column to classify
# reporters.
OECD_ISO3 = set(
    (
        "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
        "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
        "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
        "SVK SVN ESP SWE CHE TUR GBR USA"
    ).split()
)


def load_data():
    """
    Read the processed bilateral panel and print a one-line summary of it.

    Returns
    -------
    pandas.DataFrame
        Contents of ``PROCESSED_DIR / "bilateral_panel.csv"``.
    """
    print("Loading bilateral panel...")
    panel = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")

    n_obs = len(panel)
    n_reporters = panel['reporter'].nunique()
    n_partners = panel['partner'].nunique()
    first_year = int(panel['year'].min())
    last_year = int(panel['year'].max())
    print(
        f"  Loaded: {n_obs:,} obs, {n_reporters} reporters, "
        f"{n_partners} partners, years {first_year}-{last_year}"
    )
    return panel


def section1_reporter_level_correlations(df):
    """
    Collapse the dyadic panel to reporter-years and relate footprint to Z_1.

    For every (reporter, year) group the function builds:
      - n_total_pairs: number of rows with a non-null ``partner``
      - n_positive: sum of the ``has_portfolio_total`` indicator
      - total_position: sum of ``portfolio_total``
      - avg_position / median_position: over rows with has_portfolio_total == 1
      - connection_rate = n_positive / n_total_pairs
      - log_avg_position, log_n_positive (values below 1 are raised to 1
        before taking the log; NaN passes through)
    plus Z_1, KAOPEN, GDP per capita (PPP) and nominal GDP, each taken as the
    first non-null value of the reporter-side column within the group.

    On reporter-years with n_positive > 0 it then reports:
      - pooled Pearson correlations of Z_1 with four outcomes (each only when
        more than 10 complete observations exist),
      - OLS of three outcomes on Z_1 + log(ngdp_usd) + year dummies, with
        standard errors clustered by reporter,
      - the correlation between n_positive and log_avg_position.

    Returns
    -------
    tuple
        (markdown lines, full reporter-year frame, reporter-year frame
        restricted to n_positive > 0).
    """
    banner = "=" * 70
    print("\n" + banner)
    print("SECTION 1: REPORTER-LEVEL CORRELATIONS")
    print(banner)

    def stars(pval):
        # Conventional significance markers; a NaN p-value gets none.
        if pval < 0.01:
            return '***'
        if pval < 0.05:
            return '**'
        if pval < 0.1:
            return '*'
        return ''

    output = [
        "## Section 1: Reporter-Level Correlations",
        "",
        "For each reporter-year, we compute the number of positive bilateral",
        "portfolio connections (extensive) and the average/total position size",
        "(intensive), then correlate with the reporter's Z₁.",
        "",
    ]

    group_keys = ['reporter', 'year']

    # Reporter-year aggregates over every observed partner row.
    ry = (
        df.groupby(group_keys)
        .agg(
            n_total_pairs=('partner', 'count'),
            n_positive=('has_portfolio_total', 'sum'),
            total_position=('portfolio_total', 'sum'),
            Z_1=('Z_1_i', 'first'),
            kaopen=('kaopen_i', 'first'),
            gdp_pc_ppp=('gdp_pc_ppp_i', 'first'),
            ngdp_usd=('ngdp_usd_i', 'first'),
        )
        .reset_index()
    )

    # Position-size statistics restricted to links that are actually positive.
    positive_rows = df.loc[df['has_portfolio_total'] == 1]
    size_stats = (
        positive_rows.groupby(group_keys)
        .agg(
            avg_position=('portfolio_total', 'mean'),
            median_position=('portfolio_total', 'median'),
        )
        .reset_index()
    )

    ry = ry.merge(size_stats, on=group_keys, how='left')
    ry['connection_rate'] = ry['n_positive'] / ry['n_total_pairs']
    ry['log_avg_position'] = np.log(ry['avg_position'].clip(lower=1))
    ry['log_n_positive'] = np.log(ry['n_positive'].clip(lower=1))

    # Reporter-years that hold at least one positive link.
    ry_pos = ry.loc[ry['n_positive'] > 0].copy()

    print(f"  Reporter-years total: {len(ry):,}")
    print(f"  Reporter-years with >0 connections: {len(ry_pos):,}")

    # --- Pooled cross-sectional correlations ---
    output += [
        "### Pooled Correlations: Z₁ vs Reporter-Level Outcomes",
        "",
        "| Outcome | N | Correlation | p-value |",
        "|---------|---|-------------|---------|",
    ]
    outcomes = (
        ('n_positive', 'Number of connections'),
        ('connection_rate', 'Connection rate (extensive margin)'),
        ('log_avg_position', 'Log avg position size (intensive margin)'),
        ('total_position', 'Total portfolio holdings'),
    )
    for column, label in outcomes:
        sample = ry_pos.dropna(subset=['Z_1', column])
        n_obs = len(sample)
        if n_obs <= 10:
            continue
        r, p = stats.pearsonr(sample['Z_1'], sample[column])
        print(f"  Z_1 vs {label}: r={r:.4f}, p={p:.4f}, N={n_obs}")
        output.append(f"| {label} | {n_obs:,} | {r:.4f} | {p:.4f} |")
    output.append("")

    # --- OLS: outcome ~ Z_1 + log(GDP) + year FE, clustered by reporter ---
    print("\n  OLS regressions (reporter-year level):")
    output += ["### OLS Regressions: Reporter-Year Level", ""]

    regressions = (
        ('n_positive', 'N connections'),
        ('connection_rate', 'Connection rate'),
        ('log_avg_position', 'Log avg position'),
    )
    for dep_var, dep_label in regressions:
        est = ry_pos.dropna(subset=[dep_var, 'Z_1', 'ngdp_usd']).copy()
        est['log_gdp'] = np.log(est['ngdp_usd'].clip(lower=1))

        # Year dummies, omitting the earliest year as the base category.
        yr_cols = []
        for year in sorted(est['year'].unique())[1:]:
            dummy = f'yr_{int(year)}'
            est[dummy] = (est['year'] == year).astype(float)
            yr_cols.append(dummy)

        design = sm.add_constant(est[['Z_1', 'log_gdp'] + yr_cols])
        fit = sm.OLS(est[dep_var], design).fit(
            cov_type='cluster', cov_kwds={'groups': est['reporter']}
        )

        coef = fit.params['Z_1']
        se = fit.bse['Z_1']
        pval = fit.pvalues['Z_1']
        sig = stars(pval)

        print(f"    {dep_label}: Z_1 = {coef:.4f} (se={se:.4f}, p={pval:.4f}) {sig}, "
              f"R²={fit.rsquared:.4f}, N={fit.nobs:.0f}")
        output.append(f"**{dep_label}**: Z₁ = {coef:.4f} (SE = {se:.4f}, p = {pval:.4f}{sig}), "
                      f"R² = {fit.rsquared:.4f}, N = {fit.nobs:.0f}")
    output.append("")

    # --- Do reporters with more links hold smaller average positions? ---
    complete = ry_pos.dropna(subset=['n_positive', 'log_avg_position'])
    r_conn_size, p_conn_size = stats.pearsonr(complete['n_positive'], complete['log_avg_position'])
    print(f"\n  Correlation: N connections vs log avg position: r={r_conn_size:.4f}, p={p_conn_size:.4f}")
    output += [
        f"**Cross-check**: Correlation between N connections and log avg position: "
        f"r = {r_conn_size:.4f} (p = {p_conn_size:.4f})",
        "",
    ]

    return output, ry, ry_pos


def section2_extensive_kaopen(df):
    """
    Test whether financial openness moderates the extensive-margin demographic effect.

    Fits three pooled logits of ``has_portfolio_total`` (statsmodels default,
    non-robust standard errors):

      A. gravity + dZ_1..dZ_3
      B. A + kaopen_j + dZ_k x kaopen_j interactions (partner openness)
      C. A + kaopen_i + dZ_1 x kaopen_i (origin openness)

    Partner-openness interaction columns that are missing from ``df`` are
    built as plain products on a local copy. Any model that cannot be
    estimated (missing columns, convergence/separation errors) is reported
    as failed in both the console and the markdown instead of aborting the run.

    Parameters
    ----------
    df : pandas.DataFrame
        Dyadic panel. It is not modified.

    Returns
    -------
    list of str
        Markdown lines for this section.
    """
    print("\n" + "=" * 70)
    print("SECTION 2: EXTENSIVE MARGIN WITH KAOPEN INTERACTIONS")
    print("=" * 70)

    output = [
        "## Section 2: Extensive Margin with KAOPEN Interactions",
        "",
        "Does partner financial openness (KAOPEN_j) moderate the demographic",
        "effect on connection formation?",
        "",
    ]

    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product']
    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    kaopen_vars = ['kaopen_j'] + [f'{dz}_x_kaopen_j' for dz in demo_vars]

    def _stars(pv):
        # Conventional significance markers; a NaN p-value gets none.
        return '***' if pv < 0.01 else '**' if pv < 0.05 else '*' if pv < 0.1 else ''

    def _fit_logit(data, regressors):
        """Drop incomplete rows and fit a pooled logit; return (sample, result)."""
        est = data.dropna(subset=['has_portfolio_total'] + regressors)
        # has_constant='add' guarantees the intercept is column 0 even if some
        # regressor is constant, so coefficient i of `regressors` is params[i + 1].
        X = sm.add_constant(est[regressors].values, has_constant='add')
        result = sm.Logit(est['has_portfolio_total'].values, X).fit(disp=0, maxiter=100)
        return est, result

    def _coef_table(result, regressors):
        """Print every coefficient and return the markdown table (plus trailing blank)."""
        rows = ["| Variable | Coefficient | SE | p-value |",
                "|----------|-------------|-----|---------|"]
        for i, v in enumerate(regressors, start=1):
            coef, se, pv = result.params[i], result.bse[i], result.pvalues[i]
            sig = _stars(pv)
            print(f"    {v:<30} {coef:>10.4f} {se:>10.4f} {pv:>8.4f} {sig}")
            rows.append(f"| {v} | {coef:.4f} | {se:.4f} | {pv:.4f}{sig} |")
        rows.append("")
        return rows

    # Private working copy: interaction columns are added below and the
    # caller's frame must stay untouched.
    data = df.copy()
    if 'kaopen_j' in data.columns:
        for dz in demo_vars:
            inter = f'{dz}_x_kaopen_j'
            # Only fill in what the panel does not already provide.
            if inter not in data.columns and dz in data.columns:
                data[inter] = data[dz] * data['kaopen_j']

    # --- Models A and B: partner openness ---
    models = [("A: Extensive (no KAOPEN)", gravity_vars + demo_vars),
              ("B: Extensive (with KAOPEN)", gravity_vars + demo_vars + kaopen_vars)]
    for model_name, regressors in models:
        try:
            est, result = _fit_logit(data, regressors)
        except Exception as e:
            print(f"  {model_name}: FAILED - {e}")
            output.append(f"**{model_name}**: Failed - {e}")
            output.append("")
            continue

        print(f"\n  {model_name}: N={len(est):,}, Pseudo-R²={result.prsquared:.4f}")
        output.append(f"### {model_name}")
        output.append(f"N = {len(est):,}, Pseudo-R² = {result.prsquared:.4f}")
        output.append("")
        output += _coef_table(result, regressors)

    # --- Model C: origin openness, with a dZ_1 x kaopen_i interaction ---
    print("\n  Testing origin KAOPEN on extensive margin:")
    regressors_c = gravity_vars + demo_vars + ['kaopen_i', 'dZ_1_x_kaopen_i']
    title_c = "C: Extensive with Origin KAOPEN Interaction"
    try:
        data['dZ_1_x_kaopen_i'] = data['dZ_1'] * data['kaopen_i']
        est, result = _fit_logit(data, regressors_c)
    except Exception as e:
        print(f"  Model C failed: {e}")
        output.append(f"**{title_c}**: Failed - {e}")
        output.append("")
    else:
        print(f"  C: Extensive (origin KAOPEN): N={len(est):,}, Pseudo-R²={result.prsquared:.4f}")
        output.append(f"### {title_c}")
        output.append(f"N = {len(est):,}, Pseudo-R² = {result.prsquared:.4f}")
        output.append("")
        output += _coef_table(result, regressors_c)

    return output


def section3_oecd_split(df):
    """
    Compare OECD and non-OECD reporters on both margins.

    Reporters are classified as OECD when ``iso_o`` is in ``OECD_ISO3``. For
    each group the function reports:
      - a pooled logit of ``has_portfolio_total`` on gravity + dZ_1..dZ_3
        (extensive margin);
      - when more than 100 complete rows exist, pooled OLS of
        ``log_portfolio_total`` on the same regressors (intensive margin).
        NOTE: this is OLS with default (non-clustered) standard errors, not
        the GLS behind the headline intensive estimate;
      - reporter-year Pearson correlations of Z_1 with the number of positive
        links and with the connection rate (requires more than 10 complete
        reporter-years).

    The console shows only the dZ coefficients; the markdown tables list every
    regressor. The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Dyadic panel.

    Returns
    -------
    list of str
        Markdown lines for this section.
    """
    print("\n" + "=" * 70)
    print("SECTION 3: OECD vs NON-OECD SPLIT")
    print("=" * 70)

    output = [
        "## Section 3: OECD vs Non-OECD Split",
        "",
        "Does the extensive margin puzzle differ by development level?",
        "OECD = reporter's ISO3 is in OECD member list.",
        "",
    ]

    # Classification kept as a local mask so the caller's DataFrame does not
    # gain an extra column as a side effect.
    is_oecd = df['iso_o'].isin(OECD_ISO3)
    n_oecd = df.loc[is_oecd, 'reporter'].nunique()
    n_non = df.loc[~is_oecd, 'reporter'].nunique()
    print(f"  OECD reporters: {n_oecd}, Non-OECD reporters: {n_non}")

    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product']
    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    regressors = gravity_vars + demo_vars

    output.append(f"OECD reporters: {n_oecd}, Non-OECD reporters: {n_non}")
    output.append("")

    def _stars(pv):
        # Conventional significance markers; a NaN p-value gets none.
        return '***' if pv < 0.01 else '**' if pv < 0.05 else '*' if pv < 0.1 else ''

    def _design(frame):
        # has_constant='add' forces the intercept into column 0 even if a
        # regressor happens to be constant within a subgroup, so params[i + 1]
        # always lines up with regressors[i].
        return sm.add_constant(frame[regressors].values, has_constant='add')

    def _coef_table(result):
        """Print dZ rows to the console and return the full markdown table."""
        rows = ["| Variable | Coefficient | SE | p-value |",
                "|----------|-------------|-----|---------|"]
        for i, v in enumerate(regressors, start=1):
            coef, se, pv = result.params[i], result.bse[i], result.pvalues[i]
            sig = _stars(pv)
            if v.startswith('dZ'):
                print(f"    {v:<30} {coef:>10.4f} {se:>10.4f} {pv:>8.4f} {sig}")
            rows.append(f"| {v} | {coef:.4f} | {se:.4f} | {pv:.4f}{sig} |")
        rows.append("")
        return rows

    for group_name, mask in [("OECD", is_oecd), ("Non-OECD", ~is_oecd)]:
        sub = df[mask]

        output.append(f"### {group_name} Reporters")
        output.append("")

        # --- Extensive margin (logit) ---
        try:
            est_ext = sub.dropna(subset=['has_portfolio_total'] + regressors)
            result_ext = sm.Logit(est_ext['has_portfolio_total'].values,
                                  _design(est_ext)).fit(disp=0, maxiter=100)
        except Exception as e:
            print(f"  {group_name} Extensive: FAILED - {e}")
            output.append(f"Extensive margin failed: {e}")
            output.append("")
        else:
            print(f"\n  {group_name} Extensive: N={len(est_ext):,}, Pseudo-R²={result_ext.prsquared:.4f}")
            output.append(f"**Extensive Margin (Logit)**: N = {len(est_ext):,}, "
                          f"Pseudo-R² = {result_ext.prsquared:.4f}")
            output.append("")
            output += _coef_table(result_ext)

        # --- Intensive margin (OLS on log positions) ---
        est_int = sub.dropna(subset=['log_portfolio_total'] + regressors)
        if len(est_int) > 100:
            result_int = sm.OLS(est_int['log_portfolio_total'].values, _design(est_int)).fit()
            print(f"  {group_name} Intensive: N={len(est_int):,}, R²={result_int.rsquared:.4f}")

            output.append(f"**Intensive Margin (OLS, log portfolio)**: N = {len(est_int):,}, "
                          f"R² = {result_int.rsquared:.4f}")
            output.append("")
            output += _coef_table(result_int)

        # --- Reporter-level stats ---
        ry = sub.groupby(['reporter', 'year']).agg(
            n_positive=('has_portfolio_total', 'sum'),
            n_total=('partner', 'count'),
            Z_1=('Z_1_i', 'first'),
        ).reset_index()
        ry['conn_rate'] = ry['n_positive'] / ry['n_total']

        # Drop incomplete reporter-years (as Section 1 does): depending on the
        # SciPy version, NaN input to pearsonr either raises or yields NaN.
        ry_pos = ry[ry['n_positive'] > 0].dropna(subset=['Z_1', 'conn_rate'])
        if len(ry_pos) > 10:
            r, p = stats.pearsonr(ry_pos['Z_1'], ry_pos['n_positive'])
            r2, p2 = stats.pearsonr(ry_pos['Z_1'], ry_pos['conn_rate'])
            print(f"  {group_name}: Z_1 vs n_connections: r={r:.4f} (p={p:.4f})")
            print(f"  {group_name}: Z_1 vs connection_rate: r={r2:.4f} (p={p2:.4f})")
            output.append(f"Reporter-level: Z₁ vs N connections: r = {r:.4f} (p = {p:.4f})")
            output.append(f"Reporter-level: Z₁ vs connection rate: r = {r2:.4f} (p = {p2:.4f})")
            output.append("")

    return output


def section4_zero_decomposition(df):
    """
    Split zero bilateral positions into structural and genuine zeros.

    A reporter-year "participates" when its maximum ``has_portfolio_total``
    across partners is positive. Zero rows of non-participating
    reporter-years are counted as structural zeros. Zero rows of
    participating reporter-years are counted as genuine zeros. Reporter-years
    whose indicator is entirely missing fall into neither bucket.

    The function then:
      - re-fits the pooled extensive-margin logit on participating
        reporter-years only, and compares dZ_1 with the full-sample fit;
      - tabulates mean/median reporter Z_1 and mean GDP per capita (PPP) for
        structural zeros, genuine zeros and positive positions;
      - runs Welch's t-test of reporter Z_1, structural vs genuine zeros.

    Parameters
    ----------
    df : pandas.DataFrame
        Dyadic panel. It is not modified.

    Returns
    -------
    list of str
        Markdown lines for this section.
    """
    print("\n" + "=" * 70)
    print("SECTION 4: STRUCTURAL vs GENUINE ZEROS")
    print("=" * 70)

    output = [
        "## Section 4: Structural vs Genuine Zeros Decomposition",
        "",
        "A **structural zero** occurs when a reporter does not participate in CPIS",
        "at all (or has no positive holdings to any partner in that year).",
        "A **genuine zero** means the reporter IS in CPIS but holds zero in that",
        "specific partner.",
        "",
    ]

    def _pct(part, whole):
        # Percentage of `whole`; NaN instead of ZeroDivisionError when empty.
        return part / whole * 100 if whole else float('nan')

    def _stars(pv):
        # Conventional significance markers; a NaN p-value gets none.
        return '***' if pv < 0.01 else '**' if pv < 0.05 else '*' if pv < 0.1 else ''

    # For each reporter-year, check whether it has ANY positive portfolio holding.
    reporter_participation = df.groupby(['reporter', 'year']).agg(
        any_positive=('has_portfolio_total', 'max'),
        n_positive=('has_portfolio_total', 'sum'),
    ).reset_index()

    n_participating_ry = (reporter_participation['any_positive'] > 0).sum()
    n_nonparticipating_ry = (reporter_participation['any_positive'] == 0).sum()
    print(f"  Reporter-years participating in CPIS: {n_participating_ry}")
    print(f"  Reporter-years NOT participating: {n_nonparticipating_ry}")

    output.append(f"Reporter-years participating in CPIS (any positive holding): "
                  f"{n_participating_ry:,}")
    output.append(f"Reporter-years NOT participating: {n_nonparticipating_ry:,}")
    output.append("")

    # Attach the reporter-year participation flag to every dyadic row.
    df_z = df.merge(reporter_participation[['reporter', 'year', 'any_positive']],
                    on=['reporter', 'year'], how='left')

    # --- Zeros breakdown ---
    zeros = df_z[df_z['has_portfolio_total'] == 0]
    n_zeros_total = len(zeros)
    n_structural = (zeros['any_positive'] == 0).sum()  # reporter-year has no positive link
    n_genuine = (zeros['any_positive'] > 0).sum()       # reporter active, zero for this partner
    pct_structural = _pct(n_structural, n_zeros_total)
    pct_genuine = _pct(n_genuine, n_zeros_total)

    n_positive = (df_z['has_portfolio_total'] == 1).sum()
    n_total = len(df_z)

    print(f"\n  Total zero observations: {n_zeros_total:,}")
    print(f"  Structural zeros (reporter not in CPIS): {n_structural:,} ({pct_structural:.1f}%)")
    print(f"  Genuine zeros (reporter in CPIS, zero position): {n_genuine:,} ({pct_genuine:.1f}%)")

    output += [
        "### Zero Decomposition",
        "",
        "| Category | Count | Percentage |",
        "|----------|-------|------------|",
        f"| Total zeros | {n_zeros_total:,} | 100.0% |",
        f"| Structural zeros (reporter not in CPIS) | {n_structural:,} | {pct_structural:.1f}% |",
        f"| Genuine zeros (reporter in CPIS, zero for this partner) | {n_genuine:,} | {pct_genuine:.1f}% |",
        # The positive-position row must sit inside the table: once the
        # blank line below closes the table, a pipe row renders as stray text.
        f"| Positive positions | {n_positive:,} | {_pct(n_positive, n_total):.1f}% of full matrix |",
        "",
    ]

    # --- Does the extensive logit change without structural zeros? ---
    print("\n  Re-running extensive logit on genuine-zeros-only sample...")

    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product']
    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    regressors = gravity_vars + demo_vars

    def _fit_logit(frame):
        """Drop incomplete rows and fit the pooled logit; return (sample, result)."""
        est = frame.dropna(subset=['has_portfolio_total'] + regressors)
        # has_constant='add' keeps the intercept in column 0 so params[i + 1]
        # always matches regressors[i].
        X = sm.add_constant(est[regressors].values, has_constant='add')
        result = sm.Logit(est['has_portfolio_total'].values, X).fit(disp=0, maxiter=100)
        return est, result

    try:
        est, result = _fit_logit(df_z[df_z['any_positive'] > 0])  # participants only
    except Exception as e:
        print(f"  Genuine-zeros logit failed: {e}")
        output.append(f"Participants-only extensive logit failed: {e}")
        output.append("")
    else:
        print(f"  Genuine-zeros logit: N={len(est):,}, Pseudo-R²={result.prsquared:.4f}")

        output.append("### Extensive Logit: CPIS Participants Only (excl structural zeros)")
        output.append(f"N = {len(est):,}, Pseudo-R² = {result.prsquared:.4f}")
        output.append("")
        output.append("| Variable | Coefficient | SE | p-value |")
        output.append("|----------|-------------|-----|---------|")

        for i, v in enumerate(regressors, start=1):
            coef, se, pv = result.params[i], result.bse[i], result.pvalues[i]
            sig = _stars(pv)
            print(f"    {v:<30} {coef:>10.4f} {se:>10.4f} {pv:>8.4f} {sig}")
            output.append(f"| {v} | {coef:.4f} | {se:.4f} | {pv:.4f}{sig} |")

        output.append("")

        # Compare with the full sample (structural zeros included).
        try:
            full_est, full_result = _fit_logit(df)
        except Exception as e:
            print(f"  Full-sample logit failed: {e}")
            output.append(f"Full-sample extensive logit failed: {e}")
            output.append("")
        else:
            idx_dz1 = regressors.index('dZ_1') + 1
            output.append("### Comparison: Full Sample vs CPIS Participants Only")
            output.append("")
            output.append("| | Full Sample | CPIS Participants Only |")
            output.append("|---|---|---|")
            output.append(f"| N | {len(full_est):,} | {len(est):,} |")
            output.append(f"| dZ₁ coef | {full_result.params[idx_dz1]:.4f} | {result.params[idx_dz1]:.4f} |")
            output.append(f"| dZ₁ p-value | {full_result.pvalues[idx_dz1]:.4f} | {result.pvalues[idx_dz1]:.4f} |")
            output.append(f"| Pseudo-R² | {full_result.prsquared:.4f} | {result.prsquared:.4f} |")
            output.append("")

    # --- Demographics of structural vs genuine zeros ---
    print("\n  Demographics of zero types:")
    structural = df_z[(df_z['has_portfolio_total'] == 0) & (df_z['any_positive'] == 0)]
    genuine = df_z[(df_z['has_portfolio_total'] == 0) & (df_z['any_positive'] > 0)]

    for label, sub in [("Structural zeros", structural), ("Genuine zeros", genuine)]:
        z1_mean = sub['Z_1_i'].mean()
        z1_median = sub['Z_1_i'].median()
        gdp_mean = sub['gdp_pc_ppp_i'].mean()
        print(f"  {label}: mean Z_1_i={z1_mean:.4f}, median={z1_median:.4f}, mean GDP/cap={gdp_mean:.0f}")

    output.append("### Demographics of Zero Types")
    output.append("")
    output.append("| | Mean Z₁ (reporter) | Median Z₁ | Mean GDP/cap |")
    output.append("|---|---|---|---|")
    for label, sub in [("Structural zeros", structural),
                       ("Genuine zeros", genuine),
                       ("Positive positions", df_z[df_z['has_portfolio_total'] == 1])]:
        output.append(f"| {label} | {sub['Z_1_i'].mean():.4f} | {sub['Z_1_i'].median():.4f} | "
                      f"${sub['gdp_pc_ppp_i'].mean():,.0f} |")
    output.append("")

    # Welch t-test (unequal variances): structural vs genuine reporter Z_1.
    t_stat, t_p = stats.ttest_ind(
        structural['Z_1_i'].dropna(),
        genuine['Z_1_i'].dropna(),
        equal_var=False
    )
    print(f"  t-test (structural vs genuine Z_1_i): t={t_stat:.4f}, p={t_p:.4f}")
    output.append(f"t-test structural vs genuine Z₁: t = {t_stat:.4f}, p = {t_p:.4f}")
    output.append("")

    return output


def section5_concentration_analysis(df):
    """
    Test whether aging reporters concentrate holdings in fewer, larger positions.

    Using only rows with has_portfolio_total == 1, the function computes, per
    reporter-year:
      - each partner's share of total holdings,
      - the Herfindahl index (HHI, the sum of squared shares),
      - the number of positive positions,
      - the top-5 partner share.

    It then reports:
      - Pearson correlations of Z_1 with HHI and with the number of positions
        (requires more than 10 reporter-years with Z_1);
      - OLS of HHI on Z_1 + log(GDP per capita, PPP) + year dummies with
        reporter-clustered standard errors, pooled and separately for OECD /
        non-OECD reporters (each subgroup needs more than 50 reporter-years);
      - the correlation of Z_1 with the top-5 partner share.

    Parameters
    ----------
    df : pandas.DataFrame
        Dyadic panel. It is not modified.

    Returns
    -------
    list of str
        Markdown lines for this section.
    """
    print("\n" + "=" * 70)
    print("SECTION 5: CONCENTRATION ANALYSIS")
    print("=" * 70)

    output = [
        "## Section 5: Concentration Analysis",
        "",
        "Do aging countries concentrate their portfolio in fewer, larger positions?",
        "We compute a Herfindahl index of portfolio allocation across partners.",
        "",
    ]

    def _stars(pv):
        # Conventional significance markers; a NaN p-value gets none.
        return '***' if pv < 0.01 else '**' if pv < 0.05 else '*' if pv < 0.1 else ''

    keys = ['reporter', 'year']

    # Shares of each positive position within its reporter-year.
    pos = df[df['has_portfolio_total'] == 1].copy()
    totals = pos.groupby(keys)['portfolio_total'].sum().reset_index()
    totals.columns = ['reporter', 'year', 'total_portfolio']

    pos = pos.merge(totals, on=keys)
    pos['share'] = pos['portfolio_total'] / pos['total_portfolio']
    pos['share_sq'] = pos['share'] ** 2

    hhi = pos.groupby(keys).agg(
        hhi=('share_sq', 'sum'),
        n_positions=('partner', 'count'),
        Z_1=('Z_1_i', 'first'),
        gdp_pc_ppp=('gdp_pc_ppp_i', 'first'),
        iso_o=('iso_o', 'first'),
    ).reset_index()

    # --- Correlations (guarded like the other correlations in this file) ---
    with_z = hhi.dropna(subset=['Z_1'])
    if len(with_z) > 10:
        r_hhi, p_hhi = stats.pearsonr(with_z['Z_1'], with_z['hhi'])
        r_n, p_n = stats.pearsonr(with_z['Z_1'], with_z['n_positions'])

        print(f"  Z_1 vs HHI (concentration): r={r_hhi:.4f}, p={p_hhi:.4f}")
        print(f"  Z_1 vs N positions: r={r_n:.4f}, p={p_n:.4f}")

        output.append(f"Z₁ vs HHI (concentration): r = {r_hhi:.4f} (p = {p_hhi:.4f})")
        output.append(f"Z₁ vs N positions: r = {r_n:.4f} (p = {p_n:.4f})")
    output.append("")

    # --- OLS: HHI ~ Z_1 + log(GDP/cap) + year FE, clustered by reporter ---
    est = hhi.dropna(subset=['hhi', 'Z_1', 'gdp_pc_ppp']).copy()
    est['log_gdp'] = np.log(est['gdp_pc_ppp'].clip(lower=1))
    years = sorted(est['year'].unique())
    for yr in years[1:]:
        est[f'yr_{int(yr)}'] = (est['year'] == yr).astype(float)
    yr_cols = [f'yr_{int(yr)}' for yr in years[1:]]

    X = sm.add_constant(est[['Z_1', 'log_gdp'] + yr_cols])
    result = sm.OLS(est['hhi'], X).fit(cov_type='cluster', cov_kwds={'groups': est['reporter']})

    z1_coef = result.params['Z_1']
    z1_p = result.pvalues['Z_1']
    sig = _stars(z1_p)
    print(f"  OLS HHI ~ Z_1: coef={z1_coef:.6f}, p={z1_p:.4f} {sig}, R²={result.rsquared:.4f}")

    output.append("OLS: HHI ~ Z₁ + log(GDP/cap) + year FE (reporter-clustered)")
    output.append(f"Z₁ coefficient = {z1_coef:.6f} (p = {z1_p:.4f}{sig}), R² = {result.rsquared:.4f}")
    output.append("")

    # --- OECD split ---
    for group, mask_val in [("OECD", True), ("Non-OECD", False)]:
        sub = est[est['iso_o'].isin(OECD_ISO3) == mask_val]
        if len(sub) <= 50:
            continue
        # Re-base the year dummies within the subgroup. With the pooled set, a
        # year the subgroup lacks becomes an all-zero column. If the subgroup
        # also lacks the pooled base year, the remaining dummies sum to the
        # constant, making the design rank-deficient. Every column used here
        # already exists in `est`, because the subgroup's later years are a
        # subset of the pooled later years.
        sub_yr_cols = [f'yr_{int(yr)}' for yr in sorted(sub['year'].unique())[1:]]
        X_s = sm.add_constant(sub[['Z_1', 'log_gdp'] + sub_yr_cols])
        res_s = sm.OLS(sub['hhi'], X_s).fit(cov_type='cluster', cov_kwds={'groups': sub['reporter']})
        sig_s = _stars(res_s.pvalues['Z_1'])
        print(f"  {group}: Z_1→HHI = {res_s.params['Z_1']:.6f} (p={res_s.pvalues['Z_1']:.4f}) {sig_s}")
        output.append(f"{group}: Z₁→HHI = {res_s.params['Z_1']:.6f} "
                      f"(p = {res_s.pvalues['Z_1']:.4f}{sig_s})")

    output.append("")

    # --- Top-5 partner share ---
    pos_sorted = pos.sort_values(['reporter', 'year', 'portfolio_total'], ascending=[True, True, False])
    top5 = pos_sorted.groupby(keys).head(5)
    top5_share = top5.groupby(keys)['share'].sum().reset_index()
    top5_share.columns = ['reporter', 'year', 'top5_share']

    top5_merged = top5_share.merge(
        hhi[['reporter', 'year', 'Z_1', 'n_positions']],
        on=keys
    )
    top5_with_z = top5_merged.dropna(subset=['Z_1', 'top5_share'])

    if len(top5_with_z) > 10:
        r_t5, p_t5 = stats.pearsonr(top5_with_z['Z_1'], top5_with_z['top5_share'])
        print(f"  Z_1 vs Top-5 share: r={r_t5:.4f}, p={p_t5:.4f}")
        output.append(f"Z₁ vs Top-5 partner share: r = {r_t5:.4f} (p = {p_t5:.4f})")
        output.append("")

    return output


def build_summary(all_output):
    """
    Assemble the final markdown report.

    A fixed title and "Key Puzzle" preamble are placed before the section
    lines. The preamble quotes the headline extensive-margin and
    intensive-margin dZ₁ estimates. All lines are joined with newlines.

    Parameters
    ----------
    all_output : iterable of str
        Markdown lines produced by the analysis sections, in order.

    Returns
    -------
    str
        The complete markdown document.
    """
    preamble = [
        "# Extensive Margin Puzzle Analysis",
        "",
        "## Key Puzzle",
        "",
        "The extensive margin (logit) shows dZ₁ = -2.896 (p<0.001): aging origin",
        "countries form FEWER bilateral portfolio connections. But the intensive margin",
        "(GLS on positive positions) shows dZ₁ = +0.815 (p=0.208): aging origin",
        "countries hold LARGER positions (though not significant).",
        "",
        "The signs are **opposite**, suggesting aging countries have fewer but",
        "potentially larger bilateral portfolio links -- a concentration effect.",
        "",
    ]
    return "\n".join(preamble + list(all_output))


def main():
    """
    Run the five analysis sections and save the combined markdown report.

    The panel is loaded once and passed to each section in turn. Their
    markdown lines are concatenated under the header from ``build_summary``
    and written as UTF-8 to ``OUTPUT_DIR / "extensive_margin_analysis.md"``.
    The output directory is created if needed.
    """
    print("=" * 70)
    print("EXTENSIVE MARGIN PUZZLE ANALYSIS")
    print("=" * 70)

    df = load_data()

    all_output = []

    # Section 1: Reporter-level correlations (the reporter-year frames it also
    # returns are not needed here)
    out1, _ry, _ry_pos = section1_reporter_level_correlations(df)
    all_output.extend(out1)

    # Section 2: Extensive margin with KAOPEN
    out2 = section2_extensive_kaopen(df)
    all_output.extend(out2)

    # Section 3: OECD vs non-OECD
    out3 = section3_oecd_split(df)
    all_output.extend(out3)

    # Section 4: Structural vs genuine zeros
    out4 = section4_zero_decomposition(df)
    all_output.extend(out4)

    # Section 5: Concentration analysis
    out5 = section5_concentration_analysis(df)
    all_output.extend(out5)

    # Build and save markdown
    md = build_summary(all_output)
    # Create the destination up front. Otherwise a fresh checkout fails here,
    # after every model has already been estimated.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    outfile = OUTPUT_DIR / "extensive_margin_analysis.md"
    # The report contains non-ASCII text (Z₁, R², →). Pin UTF-8 rather than
    # relying on the platform's default encoding.
    with open(outfile, 'w', encoding='utf-8') as f:
        f.write(md)
    print(f"\n  Saved: {outfile}")

    print("\n" + "=" * 70)
    print("ANALYSIS COMPLETE")
    print("=" * 70)


# Script entry point: run the full analysis only when this file is executed
# directly, so importing the module does not trigger the pipeline.
if __name__ == "__main__":
    main()
